Tony Reina
27 JUN 2017
Neon has convolutional neural networks that classify images from the CIFAR-10 dataset into its 10 classes. However, I found it difficult to find code that would let me simply display the images, so I wrote this tutorial.
Let's load the CIFAR-10 dataset using Neon's ModelZoo helper function CIFAR10.
CIFAR-10 is a labeled dataset of 60,000 32x32 color images collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. The Neon loader provides 50,000 of the images for training and 10,000 for testing/validation.
There are 10 classes:
| Class # | Class Name |
|---|---|
| 0 | airplane |
| 1 | automobile |
| 2 | bird |
| 3 | cat |
| 4 | deer |
| 5 | dog |
| 6 | frog |
| 7 | horse |
| 8 | ship |
| 9 | truck |
The classes are mutually exclusive. For example, the images of trucks are large vehicles (fire trucks, dump trucks) rather than pickup trucks or SUVs (which could be considered automobiles).
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
In [1]:
from neon.data import CIFAR10 # Neon's helper function to download the CIFAR-10 data
from PIL import Image # The Python Image Library (PIL)
import numpy as np # Our old friend numpy
When you call CIFAR10's load_data() helper function, it will look for the CIFAR-10 data in the file path you've specified and download the file from the University of Toronto website if needed (the file is 163 MB).
According to the website, the archive contains the files data_batch_1, data_batch_2, ..., data_batch_5, as well as test_batch. Each of these files is a Python "pickled" object produced with cPickle.
These pickled batch files contain dictionaries with the images:
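For a sense of what one of those dictionaries holds, here is a minimal sketch that reads a single batch file with the standard pickle module. It assumes Python 3 and the key names `data` and `labels` described on the CIFAR-10 website. `load_batch` is a hypothetical helper, not part of Neon, and the demo writes a tiny synthetic batch file rather than touching the real download.

```python
import os
import pickle
import tempfile

import numpy as np


def load_batch(path):
    """Read one CIFAR-10 style batch file and return (images, labels).

    Hypothetical helper (not part of Neon). Assumes the dictionary keys
    'data' and 'labels' described on the CIFAR-10 website.
    """
    with open(path, 'rb') as f:
        # encoding='latin1' lets Python 3 read pickles written by Python 2's cPickle
        batch = pickle.load(f, encoding='latin1')
    images = np.asarray(batch['data'], dtype=np.uint8)  # one 3072-element row per image
    labels = np.asarray(batch['labels'])                # one class index (0-9) per image
    return images, labels


# Demo on a tiny synthetic batch so the sketch runs without the 163 MB download
fake_batch = {
    'data': np.zeros((4, 3072), dtype=np.uint8),
    'labels': [0, 3, 5, 9],
}
demo_path = os.path.join(tempfile.mkdtemp(), 'data_batch_demo')
with open(demo_path, 'wb') as f:
    pickle.dump(fake_batch, f)

demo_images, demo_labels = load_batch(demo_path)
print(demo_images.shape, demo_labels.tolist())
```

In this tutorial we don't need to do this by hand, because Neon's CIFAR10 helper reads the batch files for us.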
In [2]:
dataset = dict() # We'll create a dictionary for the training and testing/validation set images
In [3]:
pathToSaveData = './path' # The path where neon will download and extract the data files
In [4]:
cifar10 = CIFAR10(path=pathToSaveData, normalize=False, whiten=False, contrast_normalize=False)
dataset['train'], dataset['validation'], numClasses = cifar10.load_data()
The images are loaded into the 'train' and 'validation' dictionary keys. There are 50,000 train images and 10,000 validation images. The first element [0] is the image data. The second element [1] is the image label (class 0-9).
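To make that layout concrete, the sketch below builds a small synthetic stand-in with the same nesting (a pair of arrays per key, with one 3,072-element row per image and one label per row) and counts the images in each class. The array shapes are an assumption inferred from how the data is indexed in this tutorial, and `countPerClass` is a hypothetical helper, not part of Neon.

```python
import numpy as np

# Synthetic stand-in for the loaded data (assumed shapes, random contents)
rng = np.random.RandomState(0)
numImages = 20
demo = dict()
demo['train'] = (
    rng.randint(0, 256, size=(numImages, 3072)),  # element [0]: image data, one row per image
    rng.randint(0, 10, size=(numImages, 1)),      # element [1]: class labels (0-9)
)


def countPerClass(dataset, setName='train', numClasses=10):
    # Flatten the labels to a 1-D integer array and count how often each class occurs
    labels = np.asarray(dataset[setName][1]).ravel().astype(int)
    return np.bincount(labels, minlength=numClasses)


counts = countPerClass(demo)
print(demo['train'][0].shape, demo['train'][1].shape)
print(counts)
```

If the real data follows the same layout, calling `countPerClass(dataset, 'train')` on it should report how the 50,000 training images are distributed over the 10 classes.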
To display one of the images, we just need to reshape the data. We're given each 3x32x32 image as an unraveled 3,072-element vector. To display it properly, we need to reshape it to 3x32x32 and then move the 3 red-green-blue channels to the last dimension of the tensor, giving a 32x32x3 array.
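Here is a small sketch of that reshape on a synthetic vector, so you can see where each element ends up. It assumes the layout used in this tutorial: all red values first, then green, then blue, each stored row by row. The image displayed by PIL and matplotlib has the channels in the last dimension.

```python
import numpy as np

# A synthetic 3,072-element "image" whose values are just their own positions,
# so we can track where each element lands after reshaping.
flat = np.arange(3 * 32 * 32)

chw = flat.reshape(3, 32, 32)            # (channel, row, column)
hwc = np.transpose(chw, axes=[1, 2, 0])  # (row, column, channel)

row, col = 5, 7
for channel in range(3):
    # Pixel (row, col) of a channel sits at channel*1024 + row*32 + col in the flat vector
    assert hwc[row, col, channel] == channel * 1024 + row * 32 + col

print(chw.shape, hwc.shape)
```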
In [5]:
import matplotlib.pyplot as plt
%matplotlib inline
def getImage(dataset=dataset, setName='train', index=0):
    # The images are index 0 of the dictionary
    # They are stored as a 3072 element vector so we need to reshape this into a tensor.
    # The first dimension is the red/green/blue channel, the second is the pixel row, the third is the pixel column
    im = dataset[setName][0][index].reshape(3,32,32)
    # PIL and matplotlib want the red/green/blue channels last in the matrix. So we just need to rearrange
    # the tensor to put that dimension last.
    im = np.transpose(im, axes=[1, 2, 0]) # Put the 0-th dimension at the end
    # Images are supposed to be unsigned 8-bit integers. If we keep the raw images, then
    # this line is not needed. However, if we normalize or whiten the image, then the values become
    # floats. So we need to convert them back to uint8s.
    im = np.uint8(im)
    classIndex = dataset[setName][1][index][0] # This is the class label (0-9)
    classNames = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    # Now use PIL's Image helper to turn our array into something that matplotlib will understand as an image.
    return [classIndex, classNames[classIndex], Image.fromarray(im)]
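One caveat about the `np.uint8` conversion: casting float values directly does not rescale them, so a normalized or whitened image may not display sensibly. A possible workaround, sketched here, is a min-max rescale to the 0-255 range before the cast. `toDisplayRange` is a hypothetical helper, not part of Neon, and min-max scaling is just one reasonable choice that changes the apparent contrast of the image.

```python
import numpy as np


def toDisplayRange(im):
    """Linearly rescale an array to 0-255 and return it as uint8 (hypothetical helper)."""
    im = np.asarray(im, dtype=np.float64)
    lo, hi = im.min(), im.max()
    if hi == lo:
        # A constant image has no range to stretch, so return all zeros
        return np.zeros(im.shape, dtype=np.uint8)
    scaled = (im - lo) / (hi - lo) * 255.0
    return np.round(scaled).astype(np.uint8)


# Float values like those produced by normalization (synthetic example)
demoIm = np.array([[-1.0, 0.0], [0.5, 1.0]])
demoOut = toDisplayRange(demoIm)
print(demoOut)
```

With the raw images used in this tutorial (no normalization or whitening), this step is not needed.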
In [6]:
idx, name, im = getImage(dataset, 'train', 1022) # Get image 1022 of the training data.
plt.imshow(im);
plt.title(name);
plt.axis('off');
In [7]:
idx, name, im = getImage(dataset, 'train', 8888) # Get image 8888 of the training data
plt.imshow(im);
plt.title(name);
plt.axis('off');
In [8]:
idx, name, im = getImage(dataset, 'train', 5002) # Get image 5002 of the training data
plt.imshow(im);
plt.title(name);
plt.axis('off');
In [9]:
idx, name, im = getImage(dataset, 'train', 10022) # Get image 10022 of the training data
plt.imshow(im);
plt.title(name);
plt.axis('off');
In [10]:
idx, name, im = getImage(dataset, 'train', 7022) # Get image 7022 of the training data
plt.imshow(im);
plt.title(name);
plt.axis('off');
In [11]:
idx, name, im = getImage(dataset, 'validation', 1022) # Get image 1022 of the validation data
plt.imshow(im);
plt.title(name);
plt.axis('off');
In [12]:
idx, name, im = getImage(dataset, 'validation', 1031) # Get image 1031 of the validation data
plt.imshow(im);
plt.title(name);
plt.axis('off');
In [13]:
idx, name, im = getImage(dataset, 'validation', 9135) # Get image 9135 of the validation data
plt.imshow(im);
plt.title(name);
plt.axis('off');
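Instead of one cell per image, several images can be tiled into a single array and shown with one call to plt.imshow. The sketch below uses only numpy and a synthetic set of random images. `makeMontage` is a hypothetical helper, not part of Neon, and it assumes the same 3,072-element vector layout described above.

```python
import numpy as np


def makeMontage(flatImages, numCols=4):
    """Tile 3072-element image vectors into one (rows*32, cols*32, 3) uint8 array."""
    flatImages = np.asarray(flatImages)
    numImages = flatImages.shape[0]
    numRows = int(np.ceil(numImages / float(numCols)))
    # Unused tiles in the last row stay black
    montage = np.zeros((numRows * 32, numCols * 32, 3), dtype=np.uint8)
    for i in range(numImages):
        # Same reshape and transpose as in getImage: channels go last
        im = flatImages[i].reshape(3, 32, 32).transpose(1, 2, 0).astype(np.uint8)
        r, c = divmod(i, numCols)
        montage[r * 32:(r + 1) * 32, c * 32:(c + 1) * 32, :] = im
    return montage


# Six synthetic "images" of random pixels
demoImages = np.random.RandomState(0).randint(0, 256, size=(6, 3072))
grid = makeMontage(demoImages, numCols=4)
print(grid.shape)
```

With the raw data loaded earlier, something like `plt.imshow(makeMontage(dataset['train'][0][:16]))` should show the first 16 training images in one figure.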